load("@rules_cuda//cuda:defs.bzl", "cuda_library")
load(
    ":flag_validation_test.bzl",
    "cuda_library_c_dbg_flag_test",
    "cuda_library_c_dbg_static_msvcrt_flag_test",
    "cuda_library_c_fastbuild_flag_test",
    "cuda_library_c_fastbuild_static_msvcrt_flag_test",
    "cuda_library_c_opt_flag_test",
    "cuda_library_c_opt_static_msvcrt_flag_test",
    "cuda_library_compute60_flag_test",
    "cuda_library_compute60_sm61_flag_test",
    "cuda_library_compute61_sm61_flag_test",
    "cuda_library_flag_test",
    "cuda_library_sm61_flag_test",
    "cuda_library_sm90a_flag_test",
    "cuda_library_sm90a_sm90_flag_test",
    "num_actions_test",
)

# Checks the actions registered for the basic:kernel example target.
# NOTE(review): the expected count below is commented out, so what this test
# actually asserts depends on the default of num_actions_test's `num_actions`
# attribute -- confirm in flag_validation_test.bzl.
num_actions_test(
    name = "cuda_library_num_actions_test",
    target_under_test = "@rules_cuda_examples//basic:kernel",
    # 2 compiling, 2 lib, non-pic and pic respectively
    # num_actions = 4,
)

# Baseline expectations for the CudaCompile action of basic:kernel under the
# default configuration: the source language selector, an output option and an
# explicit host compiler (-ccbin) must all be present.
cuda_library_flag_test(
    name = "cuda_library_common_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-x cu", "-o", "-ccbin"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

# https://docs.bazel.build/versions/main/user-manual.html#flag--compilation_mode
#
# NOTE: -O specifies the optimization level for **host** code only, see
# https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#optimize-level-o
# Linux, dbg variant (the configuration is applied by the
# cuda_library_c_dbg_flag_test rule; presumably -c dbg -- see
# flag_validation_test.bzl).
#
# gcc defaults to -O0, so Bazel does not set that flag explicitly for host
# compiles. nvcc, however, defaults to -O3, so -O0 is expected to show up on
# the CudaCompile command line together with -g.
cuda_library_c_dbg_flag_test(
    name = "cuda_library_c_dbg_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-O0", "-g"],
    target_compatible_with = ["@platforms//os:linux"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

# Linux, fastbuild variant (configuration applied by the
# cuda_library_c_fastbuild_flag_test rule; presumably -c fastbuild).
#
# gcc defaults to -O0, so Bazel does not set that flag explicitly, whereas
# nvcc defaults to -O3; -O0 is therefore expected explicitly.
#
# "-Xcompiler -gmlt" is intentionally NOT expected: the -gmlt option mentioned
# in the Bazel user manual does not exist (the patch was not merged), see
#   https://codereview.appspot.com/4440072
#   https://gcc.gnu.org/legacy-ml/gcc-patches/2011-04/msg02075.html
cuda_library_c_fastbuild_flag_test(
    name = "cuda_library_c_fastbuild_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-O0", "--generate-line-info", "-g1"],
    not_contain_flags = ["-DNDEBUG"],
    target_compatible_with = ["@platforms//os:linux"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

# Linux, opt variant (configuration applied by the cuda_library_c_opt_flag_test
# rule; presumably -c opt). Expects optimization, NDEBUG and debug info
# disabled via -g0, and no plain -g.
# NOTE(review): this relies on the checker treating "-g" and "-g0" as distinct
# flags rather than matching substrings -- confirm in flag_validation_test.bzl.
cuda_library_c_opt_flag_test(
    name = "cuda_library_c_opt_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-g0", "-O2", "-DNDEBUG"],
    not_contain_flags = ["-g"],
    target_compatible_with = ["@platforms//os:linux"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

# Windows dynamic msvcrt and static msvcrt test
# Windows, dbg variant with the dynamic MSVC runtime: expects the debug DLL
# runtime (/MDd) with /Z7 debug info, and none of the other runtime selectors.
cuda_library_c_dbg_flag_test(
    name = "cuda_library_c_dbg_dynamic_msvcrt_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-O0", "/Z7", "/MDd"],
    not_contain_flags = ["-DNDEBUG", "/MD", "/MT", "/MTd"],
    target_compatible_with = ["@platforms//os:windows"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

# Windows, fastbuild variant with the dynamic MSVC runtime: expects the
# release DLL runtime (/MD) with /Z7 debug info and no NDEBUG define.
cuda_library_c_fastbuild_flag_test(
    name = "cuda_library_c_fastbuild_dynamic_msvcrt_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-O0", "/Z7", "/MD"],
    not_contain_flags = ["-DNDEBUG", "/MDd", "/MT", "/MTd"],
    target_compatible_with = ["@platforms//os:windows"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

# Windows, opt variant with the dynamic MSVC runtime: expects optimization,
# NDEBUG and the release DLL runtime (/MD) only.
cuda_library_c_opt_flag_test(
    name = "cuda_library_c_opt_dynamic_msvcrt_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-O2", "/Oy-", "/MD", "-DNDEBUG"],
    not_contain_flags = ["/MDd", "/MT", "/MTd"],
    target_compatible_with = ["@platforms//os:windows"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

# Windows, dbg variant with the static MSVC runtime (configuration applied by
# the *_static_msvcrt_* rule): expects the debug static runtime (/MTd) with
# /Z7 debug info, and none of the other runtime selectors.
cuda_library_c_dbg_static_msvcrt_flag_test(
    name = "cuda_library_c_dbg_static_msvcrt_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-O0", "/Z7", "/MTd"],
    not_contain_flags = ["-DNDEBUG", "/MD", "/MDd", "/MT"],
    target_compatible_with = ["@platforms//os:windows"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

# Windows, fastbuild variant with the static MSVC runtime: expects the release
# static runtime (/MT) with /Z7 debug info and no NDEBUG define.
cuda_library_c_fastbuild_static_msvcrt_flag_test(
    name = "cuda_library_c_fastbuild_static_msvcrt_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-O0", "/Z7", "/MT"],
    not_contain_flags = ["-DNDEBUG", "/MD", "/MDd", "/MTd"],
    target_compatible_with = ["@platforms//os:windows"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

# Windows, opt variant with the static MSVC runtime: expects optimization,
# NDEBUG and the release static runtime (/MT) only.
# NOTE(review): the target name says "dynamic_static", unlike its siblings
# which are named "*_static_msvcrt_flag_test". This looks like a copy-paste
# slip; renaming to "cuda_library_c_opt_static_msvcrt_flag_test" would make it
# consistent, but check for references to the current label first.
cuda_library_c_opt_static_msvcrt_flag_test(
    name = "cuda_library_c_opt_dynamic_static_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-O2", "/Oy-", "/MT", "-DNDEBUG"],
    not_contain_flags = ["/MD", "/MDd", "/MTd"],
    target_compatible_with = ["@platforms//os:windows"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

###########################################
# Test arch if no cuda_arch is specified
# TODO: switch default to use -arch=native
###########################################
# With no CUDA architecture configured, the compile action must not carry any
# -gencode option.
cuda_library_flag_test(
    name = "cuda_library_default_arch_flag_test",
    action_mnemonic = "CudaCompile",
    # TODO: enable the following once nvcc version >= 11.6 can be assumed:
    # contain_flags = ["-arch=native"],
    not_contain_flags = ["-gencode"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

# Local rebuild of the basic example kernel with extra options, used as the
# target under test by the copts/host_copts flag tests in this file.
cuda_library(
    name = "examples_basic_kernel_with_flags",
    srcs = ["@rules_cuda_examples//basic:kernel.cu"],
    hdrs = ["@rules_cuda_examples//basic:kernel.h"],
    # Passed to the CUDA compiler; the tests expect it as a bare "-G".
    copts = ["-G"],
    # Passed to the host compiler; the tests expect it as "-Xcompiler -O1".
    host_copts = ["-O1"],
    # Relocatable device code; the CudaDeviceLink test in this file targets
    # this library.
    rdc = 1,
    # "manual" keeps this fixture out of wildcard target patterns; it is only
    # built as a dependency of the tests that reference it.
    tags = ["manual"],
)

# Compile action of the fixture built with copts = ["-G"] and
# host_copts = ["-O1"]: the device option must appear bare, the host option
# must be forwarded via -Xcompiler, and -G must not be forwarded to the host
# compiler.
cuda_library_flag_test(
    name = "cuda_library_copts_and_host_copts_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-G", "-Xcompiler -O1"],
    not_contain_flags = ["-Xcompiler -G"],
    target_under_test = "examples_basic_kernel_with_flags",
)

# Device-link action of the same fixture: host_copts are still forwarded via
# -Xcompiler, while the compile-only option -G must not appear.
cuda_library_flag_test(
    name = "cuda_library_dlink_copts_and_host_copts_flag_test",
    action_mnemonic = "CudaDeviceLink",
    contain_flags = ["-Xcompiler -O1"],
    not_contain_flags = ["-G"],
    target_under_test = "examples_basic_kernel_with_flags",
)

# -fPIC is only meaningful for Linux targets
# The compile action that produces the non-PIC object (kernel.o) must not be
# given -fPIC.
cuda_library_flag_test(
    name = "cuda_library_nonpic_compile_flag_test",
    action_mnemonic = "CudaCompile",
    not_contain_flags = ["-fPIC"],
    output_name = "kernel.o",
    target_compatible_with = ["@platforms//os:linux"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

# NOTE(review): despite the "dlink" in its name, this target is attribute-for-
# attribute identical to cuda_library_nonpic_compile_flag_test (CudaCompile
# action, kernel.o output), so it does not exercise device linking. Presumably
# it was meant to inspect the "CudaDeviceLink" action of an rdc-enabled
# target -- confirm the intended mnemonic, target and output name before
# changing it.
cuda_library_flag_test(
    name = "cuda_library_nonpic_dlink_flag_test",
    action_mnemonic = "CudaCompile",
    not_contain_flags = ["-fPIC"],
    output_name = "kernel.o",
    target_compatible_with = ["@platforms//os:linux"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

# The compile action that produces the PIC object (kernel.pic.o) must forward
# -fPIC to the host compiler via -Xcompiler.
cuda_library_flag_test(
    name = "cuda_library_pic_compile_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-Xcompiler -fPIC"],
    output_name = "kernel.pic.o",
    target_compatible_with = ["@platforms//os:linux"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

# NOTE(review): despite the "dlink" in its name, this target is attribute-for-
# attribute identical to cuda_library_pic_compile_flag_test (CudaCompile
# action, kernel.pic.o output), so it does not exercise device linking.
# Presumably it was meant to inspect the "CudaDeviceLink" action of an
# rdc-enabled target -- confirm the intended mnemonic, target and output name
# before changing it.
cuda_library_flag_test(
    name = "cuda_library_pic_dlink_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-Xcompiler -fPIC"],
    output_name = "kernel.pic.o",
    target_compatible_with = ["@platforms//os:linux"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

######################
# Tests with -gencode
######################
# Architecture selection is applied by the cuda_library_sm61_flag_test rule
# (presumably an "sm_61" arch setting -- see flag_validation_test.bzl).
# Expects a single -gencode pairing compute_61 with sm_61.
cuda_library_sm61_flag_test(
    name = "cuda_library_sm61_arch_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-gencode arch=compute_61,code=sm_61"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
    # target_under_test = "//cuda:archs",
)

# Architecture selection is applied by the cuda_library_compute60_flag_test
# rule (presumably a "compute_60" arch setting). Expects the virtual
# architecture to be used for both arch= and code=.
cuda_library_compute60_flag_test(
    name = "cuda_library_compute60_arch_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-gencode arch=compute_60,code=compute_60"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

# Architecture selection is applied by the
# cuda_library_compute60_sm61_flag_test rule. Expects two -gencode options
# sharing arch=compute_60: one emitting compute_60 and one emitting sm_61.
cuda_library_compute60_sm61_flag_test(
    name = "cuda_library_compute60_sm61_arch_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-gencode arch=compute_60,code=compute_60", "-gencode arch=compute_60,code=sm_61"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

# Architecture selection is applied by the
# cuda_library_compute61_sm61_flag_test rule. Expects two -gencode options
# sharing arch=compute_61: one emitting compute_61 and one emitting sm_61.
cuda_library_compute61_sm61_flag_test(
    name = "cuda_library_compute61_sm61_arch_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-gencode arch=compute_61,code=compute_61", "-gencode arch=compute_61,code=sm_61"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

# Architecture selection is applied by the cuda_library_sm90a_flag_test rule.
# Expects the arch-specific "a" suffix to be kept on both the virtual
# (compute_90a) and the real (sm_90a) architecture.
cuda_library_sm90a_flag_test(
    name = "cuda_library_sm90a_arch_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-gencode arch=compute_90a,code=sm_90a"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)

# Architecture selection is applied by the cuda_library_sm90a_sm90_flag_test
# rule. Expects two -gencode options sharing arch=compute_90.
# NOTE(review): the second expectation pairs code=sm_90a with arch=compute_90,
# whereas the sm90a-only test in this file pairs sm_90a with compute_90a.
# Confirm against the arch spec used in flag_validation_test.bzl (and nvcc's
# rules for arch-specific targets) that this combination is what is intended.
cuda_library_sm90a_sm90_flag_test(
    name = "cuda_library_sm90a_sm90_arch_flag_test",
    action_mnemonic = "CudaCompile",
    contain_flags = ["-gencode arch=compute_90,code=sm_90", "-gencode arch=compute_90,code=sm_90a"],
    target_under_test = "@rules_cuda_examples//basic:kernel",
)
